# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torch.optim as optim
from PIL import Image
import numpy as np
import json

# Torchvision
import torchvision
import torchvision.transforms as transforms
import argparse
import os
import torchmetrics
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm

from transformers import CLIPProcessor, CLIPModel
from clip_attack import get_text_sim_map
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig

# Expose only GPU index 0 to this process (CUDA_VISIBLE_DEVICES is read when
# CUDA is first initialized, so it is set before any tensor is moved to GPU).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Run on the first visible GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def normalize(images):
    """Standardize an image batch with the hard-coded per-channel mean/std.

    The statistics are applied along dimension 1 (they are broadcast as
    ``[None, :, None, None]``), so ``images`` must have 3 entries on that axis.

    Args:
        images: Tensor whose dimension 1 holds the 3 color channels.

    Returns:
        ``(images - mean) / std`` as a new tensor on the same device as
        ``images``.
    """
    # Create the constants on the input's device instead of calling .cuda(),
    # so the function also works on CPU-only hosts and on non-default GPUs.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    images = images - mean[None, :, None, None]
    images = images / std[None, :, None, None]
    return images

def denormalize(images):
    """Undo the per-channel standardization performed by ``normalize``.

    The statistics are applied along dimension 1 (they are broadcast as
    ``[None, :, None, None]``), so ``images`` must have 3 entries on that axis.

    Args:
        images: Standardized tensor whose dimension 1 holds the 3 channels.

    Returns:
        ``images * std + mean`` as a new tensor on the same device as
        ``images``.
    """
    # Create the constants on the input's device instead of calling .cuda(),
    # so the function also works on CPU-only hosts and on non-default GPUs.
    mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=images.device)
    std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=images.device)
    images = images * std[None, :, None, None]
    images = images + mean[None, :, None, None]
    return images

def conver_to_image_mask(mask, image_size=224, grid_size=16, device='cuda:0'):
    """Expand a flat per-cell mask into a pixel-level image mask.

    The first ``grid_size ** 2`` entries of ``mask`` are read in row-major
    order (entry ``i`` belongs to grid row ``i // grid_size`` and column
    ``i % grid_size``) and each one is painted onto a square of
    ``image_size // grid_size`` pixels per side.

    Args:
        mask: Tensor with at least ``grid_size ** 2`` values once a leading
            singleton dimension is squeezed away.
        image_size: Side length, in pixels, of the square output mask.
        grid_size: Number of cells along each side of the grid.
        device: Device the returned tensor is moved to.

    Returns:
        A float32 tensor of shape ``(1, 1, image_size, image_size)`` on
        ``device``. Pixels beyond ``grid_size * (image_size // grid_size)``
        (only possible when ``image_size`` is not a multiple of
        ``grid_size``) stay zero.

    Raises:
        IndexError: If ``mask`` holds fewer than ``grid_size ** 2`` values.
    """
    # Derive the cell count from grid_size; it used to be hard-coded to 256,
    # which only matched the default 16x16 grid.
    num_cells = grid_size * grid_size
    cell = image_size // grid_size

    flat = mask.squeeze(0).reshape(-1)
    if flat.numel() < num_cells:
        raise IndexError(
            f"mask has {flat.numel()} values but a {grid_size}x{grid_size} grid needs {num_cells}"
        )

    grid = flat[:num_cells].to(torch.float32).reshape(grid_size, grid_size)

    # Blow every cell up to a (cell x cell) pixel square in one shot rather
    # than assigning the squares one by one in a Python loop.
    covered = cell * grid_size
    image_mask = torch.zeros((image_size, image_size), dtype=torch.float32, device=grid.device)
    image_mask[:covered, :covered] = grid.repeat_interleave(cell, dim=0).repeat_interleave(cell, dim=1)

    return image_mask.unsqueeze(0).unsqueeze(0).to(device)

def convert_mask_to_patch_mask(mask, patch_size=14, image_size=224):
    """Reduce a pixel-level mask to one 0/1 flag per image patch.

    Dimensions 2 and 3 of ``mask`` are cut into non-overlapping
    ``patch_size`` x ``patch_size`` tiles; a patch is flagged with 1.0 when
    the sum of the values inside its tile is strictly positive, else 0.0.

    Args:
        mask: Tensor whose dimensions 2 and 3 are the spatial axes. The final
            reshape requires exactly ``(image_size // patch_size) ** 2``
            flags in total.
        patch_size: Side length of one square patch, in pixels.
        image_size: Side length of the image, in pixels.

    Returns:
        A float tensor of shape ``(1, (image_size // patch_size) ** 2)`` with
        patches listed in row-major order.
    """
    patches_per_side = image_size // patch_size

    # Tile the rows first, then the columns; each unfold appends one trailing
    # dimension holding the pixels of a tile along that axis.
    row_tiles = mask.unfold(2, patch_size, patch_size)
    tiles = row_tiles.unfold(3, patch_size, patch_size)

    # Total mask value inside every tile, turned into a 0/1 indicator.
    per_patch_total = tiles.sum(dim=(4, 5))
    flags = per_patch_total.gt(0).float()

    return flags.view(1, patches_per_side ** 2)



# Load the CLIP model and its processor; the blank paths are placeholders
# that must be replaced with a real checkpoint location.
model = CLIPModel.from_pretrained(" ") # add your clip path here
processor = CLIPProcessor.from_pretrained(" ") # add your clip path here
model.eval()
# Only the input perturbation is optimized in this script, so the weights
# never need gradients. Freezing them keeps every backward pass from
# computing and accumulating unused parameter gradients.
model.requires_grad_(False)
model.to(device)
vision_model = model.vision_model

# Torchvision pipeline: resize to 224x224, convert to a tensor, and normalize
# with the statistics stored on the loaded processor.
preprocess = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=processor.image_processor.image_mean, std=processor.image_processor.image_std)
])

# Bound on the absolute value of every perturbation entry; the perturbation
# is added to de-normalized pixels that are clamped to [0, 1].
epsilon = 8 / 255.

# Dataset description: a mapping from image path to a record that provides
# content['attack_target']['category_name'].
with open('appear_dataset_info.json', 'r', encoding='utf-8') as file:
    data = json.load(file)


def _cls_attention_weights(attentions):
    """Collapse a sequence of attention maps into one weight per non-first token.

    The maps are stacked and averaged over the stacking dimension, then over
    dimension 1 of each map; column 0 of rows 1.. is kept and the result is
    scaled to sum to 1 over all of its entries.

    Args:
        attentions: Sequence of equally-shaped attention tensors, as returned
            in ``.attentions`` by the vision model (presumably one
            ``(batch, heads, tokens, tokens)`` tensor per layer -- confirm
            against the Hugging Face CLIP documentation).

    Returns:
        Tensor with a trailing singleton dimension so it broadcasts against
        per-token embeddings.
    """
    attn = torch.stack(attentions).mean(dim=0)
    attn = torch.mean(attn, dim=1)
    # NOTE(review): this keeps the attention paid by tokens 1.. to token 0;
    # attention from token 0 to the patches would be [:, 0, 1:] -- confirm
    # which one is intended.
    attn = attn[:, 1:, 0]
    attn = attn / attn.sum()
    return attn.unsqueeze(-1)


# For every dataset image, optimize a perturbation bounded by epsilon whose
# embeddings move toward the target/reference image inside the masked region
# while staying close to the clean image elsewhere.
for image_path, content in tqdm(data.items()):

    # File name without directory or extension; not used below, but available
    # for building a per-image save path.
    img_id = image_path.split('/')[-1]
    img_id = img_id.split('.')[0]
    target_obj = content['attack_target']["category_name"]
    print(target_obj)

    print(image_path)
    # Context managers close the underlying files as soon as the pixel data
    # has been converted.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = processor(images=image, return_tensors="pt").to(device)
    image_tensor = image['pixel_values']

    target_image_path = '  '  # your added image path
    with Image.open(target_image_path) as img:
        target_image = img.convert('RGB')
    target_image = processor(images=target_image, return_tensors="pt").to(device)
    target_image_tensor = target_image['pixel_values']

    refer_image_path = '  ' # your reference image path (resized target image with 100*100)
    with Image.open(refer_image_path) as img:
        ref_image = img.convert('RGB')
    ref_image = processor(images=ref_image, return_tensors="pt").to(device)
    ref_image_tensor = ref_image['pixel_values']

    with torch.no_grad():

        # The clean image in [0, 1] pixel space is constant for this image,
        # so it is computed once and reused by every step below.
        clean_pixels = denormalize(image_tensor)

        # Uniform start in [-epsilon, epsilon]. torch.rand samples [0, 1), so
        # clamping it (as was done before) pinned nearly every entry to
        # +epsilon instead of producing a random start.
        init_rand = ((torch.rand((1, 3, 224, 224)) * 2 - 1) * epsilon).to(device)
        adv_noise = (clean_pixels + init_rand).clamp(0, 1) - clean_pixels

        # Clean image: per-token embeddings (first token dropped), pooled
        # embedding, and attention weights.
        image_output = vision_model(image_tensor, output_attentions=True)
        image_embs = image_output.last_hidden_state[:, 1:, :]
        image_cls_embs = image_output.pooler_output
        image_attention = _cls_attention_weights(image_output.attentions)

        # Reference image: per-token embeddings and attention weights.
        ref_image_output = vision_model(ref_image_tensor, output_attentions=True)
        ref_image_embs = ref_image_output.last_hidden_state[:, 1:, :]
        ref_attention = _cls_attention_weights(ref_image_output.attentions)

        # Target image: only the pooled embedding is needed.
        target_output = vision_model(target_image_tensor, output_attentions=True)
        target_cls_embs = target_output.pooler_output

        # Pixel region that should take on the reference content.
        mask = torch.zeros((1, 1, 224, 224)).to(device)
        mask[:, :, :100, 124:] = 1   # adjust the mask according to your image

        patch_mask = convert_mask_to_patch_mask(mask)
        patch_mask = patch_mask.unsqueeze(-1)

        # Pooled-embedding goal: halfway between the clean and target images.
        fused_cls_embs = 0.5 * image_cls_embs + 0.5 * target_cls_embs

        # Rescale so the masked patches carry 0.4 of the total attention
        # weight and the remaining patches carry 0.6.
        k = 0.4 / (patch_mask * ref_attention).sum()
        m = 0.6 / ((1 - patch_mask) * image_attention).sum()

        target_attention = k * patch_mask * ref_attention + (1 - patch_mask) * m * image_attention

    for i in range(501):
        adv_noise.requires_grad_()
        # NOTE(review): the optimizer is rebuilt on every step because
        # adv_noise is re-bound to a new tensor after projection, so Adam
        # never accumulates moment estimates. Kept as-is to preserve the
        # existing update rule -- confirm this is intended.
        optimizer = optim.Adam([{'params': adv_noise, 'lr': 0.005}])

        adv_x = normalize(clean_pixels + adv_noise)

        adv_output = vision_model(adv_x, output_attentions=True)
        adv_embs = adv_output.last_hidden_state[:, 1:, :]
        adv_cls_embs = adv_output.pooler_output
        adv_attention = _cls_attention_weights(adv_output.attentions)

        # loss1: pooled embedding vs. the fused clean/target goal.
        loss1 = 1 - F.cosine_similarity(adv_cls_embs, fused_cls_embs, dim=-1).mean()
        # loss2: attention-weighted tokens vs. the reference inside the mask.
        loss2 = 1 - F.cosine_similarity(
            adv_embs * adv_attention,
            patch_mask * ref_image_embs * target_attention,
            dim=-1,
        ).sum() / patch_mask.sum()
        # loss3: attention-weighted tokens vs. the clean image outside the mask.
        loss3 = 1 - F.cosine_similarity(
            adv_embs * adv_attention,
            (1 - patch_mask) * image_embs * target_attention,
            dim=-1,
        ).sum() / (1 - patch_mask).sum()

        loss = loss1 + 2 * loss2 + 0.3 * loss3

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(
                f'Loss1: {loss1.item():.4f}, Loss2: {loss2.item():.4f}, Loss3: {loss3.item():.4f}')

        # Project back onto the epsilon ball, then make sure the perturbed
        # image is still a valid [0, 1] image.
        adv_noise = adv_noise.detach().clamp(-epsilon, epsilon)
        adv_noise = (clean_pixels + adv_noise).clamp(0, 1) - clean_pixels

        # The current adversarial image is written after every step (the
        # former `i % 1 == 0` guard was always true).
        adv_x = (clean_pixels + adv_noise).clamp(0, 1)
        pil_image = transforms.functional.to_pil_image(adv_x.squeeze(0))
        pil_image.save(' ') # your save path
